{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f7f64953",
   "metadata": {},
   "source": [
    "### Training RL Policies using L5Kit Closed-Loop Environment\n",
    "\n",
    "This notebook describes how to train RL policies for self-driving using our gym-compatible closed-loop environment.\n",
    "\n",
    "We will be using [Proximal Policy Optimization (PPO)](https://arxiv.org/abs/1707.06347) algorithm as our reinforcement learning algorithm, as it not only demonstrates remarkable performance but it is also empirically easy to tune.\n",
    "\n",
    "The PPO implementation in this notebook is based on [Stable Baselines3](https://github.com/DLR-RM/stable-baselines3) framework, a popular framework for training RL policies. Note that our environment is also compatible with [RLlib](https://docs.ray.io/en/latest/rllib.html), another popular frameworks for the same."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "806093d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "#@title Download L5 Sample Dataset and install L5Kit\n",
    "import os\n",
    "RunningInCOLAB = 'google.colab' in str(get_ipython())\n",
    "if RunningInCOLAB:\n",
    "    !wget https://raw.githubusercontent.com/lyft/l5kit/master/examples/setup_notebook_colab.sh -q\n",
    "    !sh ./setup_notebook_colab.sh\n",
    "    os.environ[\"L5KIT_DATA_FOLDER\"] = open(\"./dataset_dir.txt\", \"r\").read().strip()\n",
    "else:\n",
    "    os.environ[\"L5KIT_DATA_FOLDER\"] = \"/tmp/level5_data\"\n",
    "    print(\"Not running in Google Colab.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "585b1fe7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "\n",
    "from stable_baselines3 import PPO\n",
    "from stable_baselines3.common.callbacks import CheckpointCallback\n",
    "from stable_baselines3.common.env_util import make_vec_env\n",
    "from stable_baselines3.common.utils import get_linear_fn\n",
    "from stable_baselines3.common.vec_env import SubprocVecEnv\n",
    "\n",
    "from l5kit.configs import load_config_data\n",
    "from l5kit.environment.feature_extractor import CustomFeatureExtractor\n",
    "from l5kit.environment.callbacks import L5KitEvalCallback\n",
    "from l5kit.environment.envs.l5_env import SimulationConfigGym\n",
    "\n",
    "from l5kit.visualization.visualizer.zarr_utils import episode_out_to_visualizer_scene_gym_cle\n",
    "from l5kit.visualization.visualizer.visualizer import visualize\n",
    "from bokeh.io import output_notebook, show"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ea81f77",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset is assumed to be on the folder specified\n",
    "# in the L5KIT_DATA_FOLDER environment variable\n",
    "\n",
    "# get environment config\n",
    "env_config_path = '../gym_config.yaml'\n",
    "cfg = load_config_data(env_config_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2cead394",
   "metadata": {},
   "source": [
    "### Define Training and Evaluation Environments\n",
    "\n",
    "**Training**: We will be training the PPO policy on episodes of length 32 time-steps. We will have 4 sub-processes (training environments) that will help to parallelize and speeden up episode rollouts. The *SimConfig* dataclass will define the parameters of the episode rollout: like length of episode rollout, whether to use log-replayed agents or simulated agents etc.\n",
    "\n",
    "**Evaluation**: We will evaluate the performance of the PPO policy on the *entire* scene (~248 time-steps)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b04c8313",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train on episodes of length 32 time steps\n",
    "train_eps_length = 32\n",
    "train_envs = 4\n",
    "\n",
    "# Evaluate on entire scene (~248 time steps)\n",
    "eval_eps_length = None\n",
    "eval_envs = 1\n",
    "\n",
    "# make train env\n",
    "train_sim_cfg = SimulationConfigGym()\n",
    "train_sim_cfg.num_simulation_steps = train_eps_length + 1\n",
    "env_kwargs = {'env_config_path': env_config_path, 'use_kinematic': True, 'sim_cfg': train_sim_cfg}\n",
    "env = make_vec_env(\"L5-CLE-v0\", env_kwargs=env_kwargs, n_envs=train_envs,\n",
    "                   vec_env_cls=SubprocVecEnv, vec_env_kwargs={\"start_method\": \"fork\"})\n",
    "\n",
    "# make eval env\n",
    "validation_sim_cfg = SimulationConfigGym()\n",
    "validation_sim_cfg.num_simulation_steps = None\n",
    "eval_env_kwargs = {'env_config_path': env_config_path, 'use_kinematic': True, \\\n",
    "                   'return_info': True, 'train': False, 'sim_cfg': validation_sim_cfg}\n",
    "eval_env = make_vec_env(\"L5-CLE-v0\", env_kwargs=eval_env_kwargs, n_envs=eval_envs,\n",
    "                        vec_env_cls=SubprocVecEnv, vec_env_kwargs={\"start_method\": \"fork\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9dfc8b1a",
   "metadata": {},
   "source": [
    "### Define backbone feature extractor\n",
    "\n",
    "The backbone feature extractor is shared between the policy and the value networks. The feature extractor *simple_gn* is composed of two convolutional networks followed by a fully connected layer, with ReLU activation. The feature extractor output is passed to both the policy and value networks composed of two fully connected layers with tanh activation (SB3 default).\n",
    "\n",
    "We perform **group normalization** after every convolutional layer. Empirically, we found that group normalization performs far superior to batch normalization. This can be attributed to the fact that activation statistics change quickly in on-policy algorithms (PPO is on-policy) while batch-norm learnable parameters can be slow to update causing training issues."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4da5ac39",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A simple 2 Layer CNN architecture with group normalization\n",
    "model_arch = 'simple_gn'\n",
    "features_dim = 128\n",
    "\n",
    "# Custom Feature Extractor backbone\n",
    "policy_kwargs = {\n",
    "    \"features_extractor_class\": CustomFeatureExtractor,\n",
    "    \"features_extractor_kwargs\": {\"features_dim\": features_dim, \"model_arch\": model_arch},\n",
    "    \"normalize_images\": False\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24ec4f23",
   "metadata": {},
   "source": [
    "### Clipping Schedule\n",
    "\n",
    "We linearly decrease the value of the clipping parameter $\\epsilon$ as the PPO training progress as it shows improved training stability"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc28b605",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clipping schedule of PPO epsilon parameter\n",
    "start_val = 0.1\n",
    "end_val = 0.01\n",
    "training_progress_ratio = 1.0\n",
    "clip_schedule = get_linear_fn(start_val, end_val, training_progress_ratio)"
   ]
  },
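  {
   "cell_type": "markdown",
   "id": "3b7e9a10",
   "metadata": {},
   "source": [
    "The schedule defined above can be written out explicitly. This is a sketch that assumes SB3's convention of evaluating schedules on the *remaining* training progress $p$, which decays from 1 at the start of training to 0 at the end. The symbol $f$ stands for `training_progress_ratio` and is our notation, not SB3's:\n",
    "\n",
    "$$\\epsilon(p) = \\epsilon_{\\text{start}} + \\min\\left(\\frac{1 - p}{f},\\, 1\\right)\\left(\\epsilon_{\\text{end}} - \\epsilon_{\\text{start}}\\right)$$\n",
    "\n",
    "With $\\epsilon_{\\text{start}} = 0.1$, $\\epsilon_{\\text{end}} = 0.01$ and $f = 1$, this gives $\\epsilon(p) = 0.1 - 0.09\\,(1 - p)$, so halfway through training ($p = 0.5$) the clipping parameter is $0.055$.\n",
    "\n",
    "For reference, $\\epsilon$ bounds the probability ratio $r_t(\\theta)$ between the new and old policies in the clipped surrogate objective of the PPO paper, where $\\hat{A}_t$ is the advantage estimate:\n",
    "\n",
    "$$L^{\\text{CLIP}}(\\theta) = \\hat{\\mathbb{E}}_t\\left[\\min\\left(r_t(\\theta)\\,\\hat{A}_t,\\ \\operatorname{clip}\\left(r_t(\\theta),\\, 1 - \\epsilon,\\, 1 + \\epsilon\\right)\\hat{A}_t\\right)\\right]$$\n",
    "\n",
    "A smaller $\\epsilon$ late in training therefore permits only small policy updates."
   ]
  },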
  {
   "cell_type": "markdown",
   "id": "8596deb6",
   "metadata": {},
   "source": [
    "### Hyperparameters for PPO. \n",
    "\n",
    "For detailed description, refer https://stable-baselines3.readthedocs.io/en/master/_modules/stable_baselines3/ppo/ppo.html#PPO"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "998a927f",
   "metadata": {},
   "outputs": [],
   "source": [
    "lr = 3e-4\n",
    "num_rollout_steps = 256\n",
    "gamma = 0.8\n",
    "gae_lambda = 0.9\n",
    "n_epochs = 10\n",
    "seed = 42\n",
    "batch_size = 64\n",
    "tensorboard_log = 'tb_log'"
   ]
  },
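  {
   "cell_type": "markdown",
   "id": "5c2d8f41",
   "metadata": {},
   "source": [
    "To make these values concrete, here is the arithmetic they imply. This is a sketch that assumes SB3's PPO behaviour of collecting `num_rollout_steps` transitions from each of the 4 training environments before every policy update; the symbols $N_{\\text{rollout}}$ and $N_{\\text{grad}}$ are our notation:\n",
    "\n",
    "$$N_{\\text{rollout}} = 256 \\times 4 = 1024, \\qquad \\frac{N_{\\text{rollout}}}{\\text{batch size}} = \\frac{1024}{64} = 16 \\text{ minibatches per epoch}, \\qquad N_{\\text{grad}} = 16 \\times 10 = 160 \\text{ gradient steps per update}$$\n",
    "\n",
    "The discount factor and the GAE parameter enter through the generalized advantage estimate, where $\\delta_t$ is the temporal-difference residual of the value function $V$:\n",
    "\n",
    "$$\\hat{A}_t = \\sum_{l \\ge 0} (\\gamma\\lambda)^l\\, \\delta_{t+l}, \\qquad \\delta_t = r_t + \\gamma V(s_{t+1}) - V(s_t)$$\n",
    "\n",
    "With $\\gamma = 0.8$ and $\\lambda = 0.9$, each residual is down-weighted by $\\gamma\\lambda = 0.72$ per step, and the effective reward horizon is roughly $1 / (1 - \\gamma) = 5$ time steps, which is short compared to the 32-step training episodes."
   ]
  },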
  {
   "cell_type": "markdown",
   "id": "a310190e",
   "metadata": {},
   "source": [
    "### Define the PPO Policy.\n",
    "\n",
    "SB3 provides an easy interface to the define the PPO policy. Note: We do need to tweak appropriate hyperparameters and the custom policy backbone has been defined above.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d474019b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# define model\n",
    "model = PPO(\"CnnPolicy\", env, policy_kwargs=policy_kwargs, verbose=1, n_steps=num_rollout_steps,\n",
    "            learning_rate=lr, gamma=gamma, tensorboard_log=tensorboard_log, n_epochs=n_epochs,\n",
    "            clip_range=clip_schedule, batch_size=batch_size, seed=seed, gae_lambda=gae_lambda)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53754180",
   "metadata": {},
   "source": [
    "### Defining Callbacks\n",
    "\n",
    "We can additionally define callbacks to save model checkpoints and evaluate models during training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f19803ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "callback_list = []\n",
    "\n",
    "# Save Model Periodically\n",
    "save_freq = 100000\n",
    "save_path = './logs/'\n",
    "output = 'PPO'\n",
    "checkpoint_callback = CheckpointCallback(save_freq=(save_freq // train_envs), save_path=save_path, \\\n",
    "                                         name_prefix=output)\n",
    "callback_list.append(checkpoint_callback)\n",
    "\n",
    "# Eval Model Periodically\n",
    "eval_freq = 100000\n",
    "n_eval_episodes = 1\n",
    "val_eval_callback = L5KitEvalCallback(eval_env, eval_freq=(eval_freq // train_envs), \\\n",
    "                                      n_eval_episodes=n_eval_episodes, n_eval_envs=eval_envs)\n",
    "callback_list.append(val_eval_callback)\n"
   ]
  },
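  {
   "cell_type": "markdown",
   "id": "7a4e6b92",
   "metadata": {},
   "source": [
    "The division by `train_envs` above converts a frequency given in environment time steps into a number of callback calls. This sketch assumes SB3's convention that a callback is invoked once per vectorized step, which advances all 4 training environments at once, and that `L5KitEvalCallback` counts calls the same way as `CheckpointCallback`:\n",
    "\n",
    "$$\\text{calls between checkpoints} = \\left\\lfloor \\frac{100000}{4} \\right\\rfloor = 25000, \\qquad \\text{time steps between checkpoints} = 25000 \\times 4 = 100000$$\n",
    "\n",
    "Under these assumptions a short run never reaches the threshold: the 3000-step training call in this notebook makes fewer than 1000 callback calls, so no checkpoint is saved and no evaluation is run unless `save_freq` and `eval_freq` are lowered or training is lengthened."
   ]
  },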
  {
   "cell_type": "markdown",
   "id": "ad3bda21",
   "metadata": {},
   "source": [
    "### Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "087f4e09",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_steps = 3000\n",
    "model.learn(n_steps, callback=callback_list)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "575327a2",
   "metadata": {},
   "source": [
    "**Voila!** We have a trained PPO policy! Train for larger number of steps for better accuracy. Typical RL algorithms require training atleast 1M steps for good convergence. You can visualize the quantitiative evaluation using tensorboard."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bc491a2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Visualize Tensorboard logs (!! run on local terminal !!)\n",
    "!tensorboard --logdir tb_log"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d24f956c",
   "metadata": {},
   "source": [
    "### Visualize the episode from the environment\n",
    "\n",
    "We can easily visualize the outputs obtained by rolling out episodes in the L5Kit using the Bokeh visualizer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcd2b424",
   "metadata": {},
   "outputs": [],
   "source": [
    "rollout_sim_cfg = SimulationConfigGym()\n",
    "rollout_sim_cfg.num_simulation_steps = None\n",
    "rollout_env = gym.make(\"L5-CLE-v0\", env_config_path=env_config_path, sim_cfg=rollout_sim_cfg, \\\n",
    "                       use_kinematic=True, train=False, return_info=True)\n",
    "\n",
    "def rollout_episode(model, env, idx = 0):\n",
    "    \"\"\"Rollout a particular scene index and return the simulation output.\n",
    "\n",
    "    :param model: the RL policy\n",
    "    :param env: the gym environment\n",
    "    :param idx: the scene index to be rolled out\n",
    "    :return: the episode output of the rolled out scene\n",
    "    \"\"\"\n",
    "\n",
    "    # Set the reset_scene_id to 'idx'\n",
    "    env.reset_scene_id = idx\n",
    "    \n",
    "    # Rollout step-by-step\n",
    "    obs = env.reset()\n",
    "    done = False\n",
    "    while True:\n",
    "        action, _ = model.predict(obs, deterministic=True)\n",
    "        obs, _, done, info = env.step(action)\n",
    "        if done:\n",
    "            break\n",
    "\n",
    "    # The episode outputs are present in the key \"sim_outs\"\n",
    "    sim_out = info[\"sim_outs\"][0]\n",
    "    return sim_out\n",
    "\n",
    "# Rollout one episode\n",
    "sim_out = rollout_episode(model, rollout_env)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e383cff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# might change with different rasterizer\n",
    "map_API = rollout_env.dataset.rasterizer.sem_rast.mapAPI\n",
    "\n",
    "def visualize_outputs(sim_outs, map_API):\n",
    "    for sim_out in sim_outs: # for each scene\n",
    "        vis_in = episode_out_to_visualizer_scene_gym_cle(sim_out, map_API)\n",
    "        show(visualize(sim_out.scene_id, vis_in))\n",
    "\n",
    "output_notebook()\n",
    "visualize_outputs([sim_out], map_API)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
